from transformers import BertTokenizer, BertModel
import torch

# Pin computation to the second GPU only when more than one is available,
# so the module can still be imported on CPU-only or single-GPU machines.
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    torch.cuda.set_device(1)


class Bert:
    """Load a pre-trained BERT model and encode texts into fixed-size vectors.

    Example:
        bert = Bert()
        embedding = bert.embedding("content")
    """

    def __init__(self, plm_dir="/home/public/projects/emotion_dan/dataset/chinese_wwm_pytorch", using_gpu=True):
        self.plm_dir = plm_dir
        self.embed_size = 768  # hidden size of BERT-base
        device = "cpu"
        if using_gpu:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)
        self.tokenizer, self.model = self.load_bert_plm()

    def load_bert_plm(self):
        """
        Load the pre-trained tokenizer and model.
        Chinese BERT pre-training reference: https://www.cnblogs.com/think90/p/13091705.html
        If plm_dir is a model name such as hfl/chinese-roberta-wwm-ext, the weights are
        downloaded from the Hugging Face hub (see the commented example at the end of the
        __main__ block); otherwise point it at a local directory, e.g.
            **/PLM/chinese_wwm_pytorch
            **/PLM/chinese_roberta_wwm_ext_pytorch
        :return: tokenizer, model
        """
        tokenizer = BertTokenizer.from_pretrained(self.plm_dir)
        # Expose every layer's hidden states and attention maps; embedding() needs both.
        model = BertModel.from_pretrained(self.plm_dir,
                                          output_hidden_states=True,
                                          output_attentions=True).to(self.device)
        model.eval()  # inference only: disable dropout
        return tokenizer, model

    def embedding(self, words):
        """Encode a piece of text into a single 768-dim vector."""
        # Tokenize and truncate to at most 64 sub-word ids.
        input_ids = torch.tensor([self.tokenizer.encode(words)]).to(self.device)[:, :64]
        with torch.no_grad():
            # The last two outputs are the per-layer hidden states and attention maps.
            all_hidden_states, all_attentions = self.model(input_ids)[-2:]
        # Weight each token's second-to-last-layer state by the mean attention it
        # receives (averaged over heads and query positions), then sum over tokens.
        weights = all_attentions[-2][0].mean(dim=0).mean(dim=0).view(-1, 1)
        rep = (all_hidden_states[-2][0] * weights).sum(dim=0)
        return rep.cpu().numpy()


if __name__ == '__main__':
    bert = Bert(using_gpu=False)
    embeddings = bert.embedding("北京是个好地方")
    print(embeddings)
    print(type(embeddings))
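
    # A minimal sketch of loading the weights by Hugging Face hub name instead of a local
    # directory, as noted in load_bert_plm's docstring; hfl/chinese-roberta-wwm-ext is the
    # hub id mentioned there. Left commented out because it needs network access to download.
    # bert_hub = Bert(plm_dir="hfl/chinese-roberta-wwm-ext", using_gpu=False)
    # print(bert_hub.embedding("北京是个好地方").shape)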
